import torch
from Visualization import funcaverage
import copy

def TakeGradient(x, LossFunction):
    """Return the gradient of ``LossFunction(x)`` w.r.t. ``x`` as a ``(1, d)`` row.

    Args:
        x: A tensor that requires grad, or a sequence of such tensors
            (anything accepted as ``inputs`` by ``torch.autograd.grad``).
        LossFunction: Callable mapping ``x`` to a scalar tensor (a scalar is
            required because no ``grad_outputs`` are passed to autograd).

    Returns:
        A ``(1, d)`` tensor containing every partial derivative, where ``d`` is
        the total number of elements across all tensors in ``x``. When ``x``
        is a sequence, the per-tensor gradients are flattened and concatenated
        in order. The result keeps the gradients' own dtype and device.
    """
    grad = torch.autograd.grad(LossFunction(x), x)
    # Flatten each piece first so gradients of different shapes can be joined.
    # Concatenating the pieces directly (no CPU float32 seed tensor) avoids
    # dtype/device mismatches with float64 or GPU parameters.
    return torch.cat([g.reshape(-1) for g in grad]).view(1, -1)

def Covariance(x, LossFunctions):
    """Estimate the covariance of the per-loss gradients at ``x``.

    Computes ``(1/K) * sum_k g_k^T g_k - b^T b``, where ``g_k`` is the ``(1, d)``
    gradient of ``LossFunctions[k]`` at ``x`` and ``b`` is the gradient of
    ``funcaverage(LossFunctions)`` at ``x``. (Presumably ``funcaverage`` returns
    the mean of the losses, which makes ``b`` the mean gradient; confirm
    against ``Visualization.funcaverage``.)

    Args:
        x: A tensor that requires grad, or a sequence of such tensors. It is
            deep-copied before differentiation, so the caller's object is not
            used directly.
        LossFunctions: Non-empty sequence of callables, each mapping ``x`` to
            a scalar tensor.

    Returns:
        A ``(d, d)`` tensor, where ``d`` is the total number of gradient
        entries. Earlier versions hard-coded ``d == 2``.

    Raises:
        ValueError: If ``LossFunctions`` is empty.
    """
    if len(LossFunctions) == 0:
        raise ValueError("Covariance requires at least one loss function")
    x = copy.deepcopy(x)
    # One (1, d) gradient row per loss function, stacked into a (K, d) matrix.
    # The matrix size comes from the data rather than a fixed 2x2 buffer.
    Gradients = torch.cat([TakeGradient(x, Loss) for Loss in LossFunctions], 0)
    # G^T G / K equals the average of the per-loss outer products g_k^T g_k.
    Cov = Gradients.T @ Gradients / len(LossFunctions)
    BatchGradient = TakeGradient(x, funcaverage(LossFunctions))
    # Subtract the outer product of the batch gradient to centre the estimate.
    Cov = Cov - BatchGradient.T @ BatchGradient
    return Cov